{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using the `catboost.metrics` module\n",
    "\n",
    "In this short tutorial, we'll show you the benefits of using the metrics from `catboost.metrics`.\n",
    "\n",
    "The `catboost.metrics` module contains classes for all metrics (or loss functions) available for use with CatBoost. By importing it, you can view the list of metrics, their supported parameters, and their default values. Formulas and descriptions for the metrics are available in the documentation: https://catboost.ai/docs/concepts/loss-functions.html."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catboost import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['AUC',\n",
       " 'AUCMulticlass',\n",
       " 'Accuracy',\n",
       " 'AverageGain',\n",
       " 'BalancedAccuracy',\n",
       " 'BalancedErrorRate',\n",
       " 'BrierScore',\n",
       " 'BuiltinMetric',\n",
       " 'Combination',\n",
       " 'CrossEntropy',\n",
       " 'CtrFactor',\n",
       " 'DCG',\n",
       " 'DataFrame',\n",
       " 'ERR',\n",
       " 'Expectile',\n",
       " 'F1',\n",
       " 'FairLoss',\n",
       " 'FilteredDCG',\n",
       " 'HammingLoss',\n",
       " 'HingeLoss',\n",
       " 'Huber',\n",
       " 'Kappa',\n",
       " 'LogLikelihoodOfPrediction',\n",
       " 'LogLinQuantile',\n",
       " 'Logloss',\n",
       " 'Lq',\n",
       " 'MAE',\n",
       " 'MAP',\n",
       " 'MAPE',\n",
       " 'MCC',\n",
       " 'MRR',\n",
       " 'MSLE',\n",
       " 'MedianAbsoluteError',\n",
       " 'MultiClass',\n",
       " 'MultiClassOneVsAll',\n",
       " 'MultiRMSE',\n",
       " 'NDCG',\n",
       " 'NormalizedGini',\n",
       " 'NormalizedGiniMulticlass',\n",
       " 'NumErrors',\n",
       " 'PFound',\n",
       " 'PRAUC',\n",
       " 'PRAUCMulticlass',\n",
       " 'PairAccuracy',\n",
       " 'PairLogit',\n",
       " 'Poisson',\n",
       " 'Precision',\n",
       " 'PrecisionAt',\n",
       " 'Quantile',\n",
       " 'QueryAUC',\n",
       " 'QueryAUCMulticlass',\n",
       " 'QueryAverage',\n",
       " 'QueryCrossEntropy',\n",
       " 'QueryRMSE',\n",
       " 'QuerySoftMax',\n",
       " 'R2',\n",
       " 'RMSE',\n",
       " 'RMSEWithUncertainty',\n",
       " 'Recall',\n",
       " 'RecallAt',\n",
       " 'SMAPE',\n",
       " 'Series',\n",
       " 'TotalF1',\n",
       " 'Tweedie',\n",
       " 'UserQuerywiseMetric',\n",
       " 'WKappa',\n",
       " 'ZeroOneLoss',\n",
       " '_ARRAY_TYPES',\n",
       " '_MetricGenerator',\n",
       " '__all__',\n",
       " '__builtins__',\n",
       " '__cached__',\n",
       " '__doc__',\n",
       " '__file__',\n",
       " '__loader__',\n",
       " '__name__',\n",
       " '__package__',\n",
       " '__spec__',\n",
       " '_catboost',\n",
       " '_current_params',\n",
       " '_del_param',\n",
       " '_generate_metric_classes',\n",
       " '_get_param',\n",
       " '_set_param',\n",
       " '_to_string',\n",
       " 'copy',\n",
       " 'np',\n",
       " 'partial',\n",
       " 'sys']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can inspect the metrics. For example, `Lq` has one mandatory parameter, `q`, and two optional ones: `use_weights`, which defaults to `True`, and `hints`, which defaults to an empty string:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Builtin metric: 'Lq'\n",
       "Parameters:\n",
       "    q (mandatory)\n",
       "    use_weights = True (default value)\n",
       "    hints = '' (default value)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.Lq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast, the `AUC` metric has no mandatory parameters, and its default hint `skip_train~true` disables its computation on the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Builtin metric: 'AUC'\n",
       "Parameters:\n",
       "    type = 'Classic' (default value)\n",
       "    hints = 'skip_train~true' (default value)\n",
       "    use_weights = False (default value)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.AUC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use a metric, first create an instance by supplying the required parameters. The constructor raises a `ValueError` if you omit a mandatory parameter or misspell a parameter name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Parameter q is mandatory and must be specified.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-39f8f2b18b7d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLq\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Desktop/YDS/MLPrac/catboost/catboost/python-package/catboost/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(cls, **kwargs)\u001b[0m\n\u001b[1;32m    174\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_set\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparam_is_set\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 176\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Parameter {} is mandatory and must be specified.'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Parameter q is mandatory and must be specified."
     ]
    }
   ],
   "source": [
    "m = metrics.Lq()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Unexpected parameter Q",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-435b9f63fcec>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLq\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mQ\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Desktop/YDS/MLPrac/catboost/catboost/python-package/catboost/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(cls, **kwargs)\u001b[0m\n\u001b[1;32m    167\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_valid_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unexpected parameter {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    170\u001b[0m             \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    171\u001b[0m             \u001b[0mparam_is_set\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Unexpected parameter Q"
     ]
    }
   ],
   "source": [
    "m = metrics.Lq(Q=4)"
   ]
  },
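  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Outside an interactive session you may want to handle these errors rather than let them stop the program. The cell below is a minimal sketch that relies only on what the tracebacks above show (a `ValueError` with a descriptive message); the helper name `try_create_lq` is made up for this example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catboost import metrics\n",
    "\n",
    "\n",
    "def try_create_lq(**params):\n",
    "    # Hypothetical helper: return an Lq metric, or None if the parameters are rejected\n",
    "    try:\n",
    "        return metrics.Lq(**params)\n",
    "    except ValueError as err:\n",
    "        print('Could not create Lq:', err)\n",
    "        return None\n",
    "\n",
    "\n",
    "try_create_lq()      # missing mandatory parameter q\n",
    "try_create_lq(Q=4)   # misspelled parameter name\n",
    "try_create_lq(q=4)   # valid parameters"
   ]
  },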
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Lq(q=4 [mandatory=True], use_weights=True [mandatory=False])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# successful creation of a metric:\n",
    "m = metrics.Lq(q=4)\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Metric classes have some helper methods, such as:\n",
    "* `set_hints(**kwargs)` – for setting hints,\n",
    "* `params_with_defaults()` – for programmatic access to parameter information,\n",
    "* `eval(labels, approxes, <other params: weight, group_id, group_weight, subgroup_id, pairs, thread_count>)` – for evaluating the metric on the given true labels and model predictions.\n",
    "\n",
    "Parameter values set on an instance are also available as attributes (for example, `m.q`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Lq(hints='skip_train~true' [mandatory=False], q=4 [mandatory=True], use_weights=True [mandatory=False])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.set_hints(skip_train=True)\n",
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'q': {'default_value': None, 'is_mandatory': True},\n",
       " 'use_weights': {'default_value': True, 'is_mandatory': False},\n",
       " 'hints': {'default_value': '', 'is_mandatory': False}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.params_with_defaults()"
   ]
  },
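  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `params_with_defaults()` returns a plain dictionary, it can be processed like any other. As a small sketch, assuming the dictionary layout shown in the output above (`default_value` and `is_mandatory` keys), the following separates mandatory parameters from optional ones:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catboost import metrics\n",
    "\n",
    "params = metrics.Lq(q=4).params_with_defaults()\n",
    "\n",
    "# Names of parameters that must be passed to the constructor\n",
    "mandatory = [name for name, info in params.items() if info['is_mandatory']]\n",
    "\n",
    "# Optional parameters mapped to their default values\n",
    "optional = {name: info['default_value']\n",
    "            for name, info in params.items() if not info['is_mandatory']}\n",
    "\n",
    "print('mandatory:', mandatory)\n",
    "print('optional defaults:', optional)"
   ]
  },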
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use metric objects in training, as the `loss_function` or in `custom_metric`, and call `eval` on them directly to score predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:\ttotal: 50.8ms\tremaining: 203ms\n",
      "1:\ttotal: 51.5ms\tremaining: 77.3ms\n",
      "2:\ttotal: 51.9ms\tremaining: 34.6ms\n",
      "3:\ttotal: 52.3ms\tremaining: 13.1ms\n",
      "4:\ttotal: 52.8ms\tremaining: 0us\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<catboost.core.CatBoostRegressor at 0x7ff4b3376790>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from catboost import CatBoostRegressor\n",
    "\n",
    "# Initialize toy data\n",
    "train_data = [[1, 4, 5, 6],\n",
    "              [4, 5, 6, 7],\n",
    "              [30, 40, 50, 60]]\n",
    "\n",
    "eval_data = [[2, 4, 6, 8],\n",
    "             [1, 4, 50, 60]]\n",
    "\n",
    "train_labels = [10, 20, 30]\n",
    "\n",
    "eval_labels = [12, 15]\n",
    "\n",
    "# Initialize CatBoostRegressor\n",
    "model = CatBoostRegressor(\n",
    "    iterations=5,\n",
    "    learning_rate=1,\n",
    "    depth=2,\n",
    "    loss_function=m,\n",
    "    custom_metric=[metrics.Lq(q=5), metrics.R2()]\n",
    ")\n",
    "\n",
    "# Fit model\n",
    "model.fit(train_data, train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'learn': {'Lq:q=5': 388.4841340753533, 'R2': 0.8778546529631635}}\n"
     ]
    }
   ],
   "source": [
    "print(model.get_best_score())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get predictions\n",
    "preds = model.predict(eval_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5.107940380990622]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.RMSE().eval(eval_labels, preds)"
   ]
  }
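  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the same value can be computed by hand. This sketch assumes the unweighted RMSE is the usual square root of the mean squared error and reuses `eval_labels` and `preds` from the cells above; the two numbers should agree up to floating-point error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from catboost import metrics\n",
    "\n",
    "# eval_labels and preds come from the earlier cells of this notebook\n",
    "errors = np.asarray(preds, dtype=float) - np.asarray(eval_labels, dtype=float)\n",
    "manual_rmse = float(np.sqrt(np.mean(errors ** 2)))\n",
    "\n",
    "# eval returns a list, as the output above shows, so take its first element\n",
    "catboost_rmse = metrics.RMSE().eval(eval_labels, preds)[0]\n",
    "\n",
    "print(manual_rmse, catboost_rmse)"
   ]
  }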
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
